{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "!pip install wget\n",
    "!pip install nemo_toolkit[tts]\n",
    "\n",
    "# -p keeps this cell re-runnable: a plain mkdir reports an error once configs/ exists.\n",
    "!mkdir -p configs\n",
    "# Fetch the Tacotron 2 and WaveGlow model configs into configs/.\n",
    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/master/examples/tts/configs/tacotron2.yaml\n",
    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/master/examples/tts/configs/waveglow.yaml\n",
    "# The config-loading cell reads the YAML files from this directory.\n",
    "CONFIG_PATH = \"configs/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import math\n",
    "import os\n",
    "import copy\n",
    "import shutil\n",
    "import librosa\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "from scipy.io.wavfile import write\n",
    "import numpy as np\n",
    "import IPython.display as ipd\n",
    "\n",
    "from ruamel.yaml import YAML\n",
    "\n",
    "import torch\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr\n",
    "import nemo.collections.tts as nemo_tts\n",
    "import nemo.utils.argparse as nm_argparse\n",
    "\n",
    "logging = nemo.logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the repository-relative config directory when the Colab\n",
    "# setup cell (which defines CONFIG_PATH) has not been run.\n",
    "try:\n",
    "    CONFIG_PATH\n",
    "except NameError:\n",
    "    CONFIG_PATH = os.path.join(\"..\", \"configs\")\n",
    "\n",
    "# Locations of the two model config files (nothing is downloaded here).\n",
    "config_path = os.path.join(CONFIG_PATH, 'tacotron2.yaml')\n",
    "waveglow_config_path = os.path.join(CONFIG_PATH, 'waveglow.yaml')\n",
    "\n",
    "# Parse both YAML files with the safe loader.\n",
    "yaml = YAML(typ=\"safe\")\n",
    "with open(config_path) as file:\n",
    "    tacotron2_config = yaml.load(file)\n",
    "\n",
    "# The label vocabulary is stored in the Tacotron 2 config.\n",
    "labels = tacotron2_config[\"labels\"]\n",
    "\n",
    "with open(waveglow_config_path) as file:\n",
    "    waveglow_config = yaml.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download pre-trained checkpoints\n",
    "\n",
    "Note: The checkpoint for WaveGlow is very large (>1GB), so please ensure you have sufficient storage space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local paths where each pre-trained checkpoint is expected.\n",
    "base_checkpoint_path = './checkpoints/'\n",
    "WAVEGLOW = os.path.join(base_checkpoint_path, 'WaveGlowNM.pt')\n",
    "TACOTRON_ENCODER = os.path.join(base_checkpoint_path, 'Tacotron2Encoder.pt')\n",
    "TACOTRON_DECODER = os.path.join(base_checkpoint_path, 'Tacotron2Decoder.pt')\n",
    "TACOTRON_POSTNET = os.path.join(base_checkpoint_path, 'Tacotron2Postnet.pt')\n",
    "TEXT_EMBEDDING = os.path.join(base_checkpoint_path, 'TextEmbedding.pt')\n",
    "\n",
    "# exist_ok makes the cell safe to re-run without a separate existence check.\n",
    "os.makedirs(base_checkpoint_path, exist_ok=True)\n",
    "\n",
    "# Each checkpoint is fetched only when it is not already on disk.\n",
    "# (The WaveGlow command previously read `wget wget <url>`, which handed wget a\n",
    "# bogus extra URL named \"wget\".)\n",
    "if not os.path.exists(WAVEGLOW):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ljspeech/versions/2/files/WaveGlowNM.pt -P {base_checkpoint_path};\n",
    "\n",
    "if not os.path.exists(TACOTRON_ENCODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Encoder.pt -P {base_checkpoint_path};\n",
    "\n",
    "if not os.path.exists(TACOTRON_DECODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Decoder.pt -P {base_checkpoint_path};\n",
    "\n",
    "if not os.path.exists(TACOTRON_POSTNET):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/Tacotron2Postnet.pt -P {base_checkpoint_path};\n",
    "\n",
    "if not os.path.exists(TEXT_EMBEDDING):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2_ljspeech/versions/2/files/TextEmbedding.pt -P {base_checkpoint_path};\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the NeMo factory; the run_* functions below call its infer() method.\n",
    "neural_factory = nemo.core.NeuralModuleFactory(optimization_level=\"O0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text Line Data Layer\n",
    "\n",
    "Construct a simple data layer that loads a single line of text (accepted from the user) and passes it to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.backends.pytorch import DataLayerNM\n",
    "from nemo.core.neural_types import *\n",
    "from nemo.utils.misc import pad_to\n",
    "from nemo.collections.asr.parts.dataset import TranscriptDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SentenceDataLayer(DataLayerNM):\n",
    "    \"\"\"A simple Neural Module for loading textual transcript data.\n",
    "\n",
    "    The text file at ``path`` is wrapped in a ``TranscriptDataset``. The\n",
    "    dataset arguments are kept on the instance so that ``update_dataset`` can\n",
    "    rebuild the dataset after the file has been rewritten.\n",
    "\n",
    "    Args:\n",
    "        path (str): Path of the text file, forwarded to ``TranscriptDataset``.\n",
    "        labels (list): Label vocabulary, forwarded to ``TranscriptDataset``.\n",
    "        batch_size (int): Size of batches to generate in data loader.\n",
    "            Accepted for interface compatibility; not read by this class.\n",
    "        bos_id (int): Forwarded to ``TranscriptDataset``. Defaults to None.\n",
    "        eos_id (int): Forwarded to ``TranscriptDataset``. Defaults to None.\n",
    "        pad_id (int): Label position of padding symbol.\n",
    "            Accepted for interface compatibility; not read by this class.\n",
    "        drop_last (bool): Whether we drop last (possibly) incomplete batch.\n",
    "            Defaults to False. Not read by this class.\n",
    "        num_workers (int): Number of processes to work on data loading (0 for\n",
    "            just main process).\n",
    "            Defaults to 0. Not read by this class.\n",
    "        shuffle (bool): Defaults to True. Not read by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    @property\n",
    "    def output_ports(self):\n",
    "        \"\"\"Returns definitions of module output ports.\n",
    "\n",
    "        texts:\n",
    "            NeuralType with axes ('B', 'T') and LabelsType elements.\n",
    "\n",
    "        texts_length:\n",
    "            NeuralType with axis ('B',) and LengthsType elements.\n",
    "        \"\"\"\n",
    "        return {\n",
    "            'texts': NeuralType(('B', 'T'), LabelsType()),\n",
    "            'texts_length': NeuralType(tuple('B'), LengthsType()),\n",
    "        }\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        path,\n",
    "        labels,\n",
    "        batch_size,\n",
    "        bos_id=None,\n",
    "        eos_id=None,\n",
    "        pad_id=None,\n",
    "        drop_last=False,\n",
    "        num_workers=0,\n",
    "        shuffle=True,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        # Keep the dataset arguments so update_dataset() can rebuild the\n",
    "        # dataset from the same file later on.\n",
    "        self.dataset_params = {\n",
    "            'path': path,\n",
    "            'labels': labels,\n",
    "            'bos_id': bos_id,\n",
    "            'eos_id': eos_id,\n",
    "        }\n",
    "\n",
    "        self._dataset = TranscriptDataset(**self.dataset_params)\n",
    "\n",
    "        # batch_size, pad_id, drop_last, num_workers and shuffle are not used:\n",
    "        # this class builds no data loader of its own (data_iterator is None).\n",
    "\n",
    "    def update_dataset(self):\n",
    "        \"\"\"Rebuild the dataset from the stored arguments (re-reads ``path``).\"\"\"\n",
    "        self._dataset = TranscriptDataset(**self.dataset_params)\n",
    "        logging.info('Dataset updated.')\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of items in the current dataset.\"\"\"\n",
    "        return len(self._dataset)\n",
    "\n",
    "    @property\n",
    "    def dataset(self):\n",
    "        \"\"\"The current ``TranscriptDataset`` instance.\"\"\"\n",
    "        return self._dataset\n",
    "\n",
    "    @property\n",
    "    def data_iterator(self):\n",
    "        \"\"\"Always None; no custom iterator is provided.\n",
    "\n",
    "        NOTE(review): presumably the framework then iterates over ``dataset``\n",
    "        with its own loader settings - confirm against DataLayerNM.\n",
    "        \"\"\"\n",
    "        return None\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create the Tacotron 2 + WaveGlow Neural Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_NMs(tacotron2_config, waveglow_config, labels, decoder_infer=False, waveglow_sigma=0.6):\n",
    "    \"\"\"Build the Tacotron 2 and WaveGlow modules and restore their checkpoints.\n",
    "\n",
    "    Args:\n",
    "        tacotron2_config (dict): Parsed Tacotron 2 config; the ``init_params``\n",
    "            entry of each module section supplies the constructor arguments.\n",
    "        waveglow_config (dict): Parsed WaveGlow config; ``WaveGlowNM.init_params``\n",
    "            supplies the constructor arguments.\n",
    "        labels (list): Label vocabulary; the text embedding is sized to\n",
    "            ``len(labels) + 3`` symbols.\n",
    "        decoder_infer (bool): Not read by this function; the inference decoder\n",
    "            (``Tacotron2DecoderInfer``) is always created.\n",
    "        waveglow_sigma (float): Value stored as the WaveGlow ``sigma`` argument.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(tacotron2_modules, waveglow)`` where ``tacotron2_modules`` is\n",
    "        ``(data_preprocessor, text_embedding, t2_enc, t2_dec, t2_postnet,\n",
    "        t2_loss, makegatetarget)``.\n",
    "    \"\"\"\n",
    "    separator = '================================'\n",
    "\n",
    "    preprocessor_cfg = tacotron2_config[\"AudioToMelSpectrogramPreprocessor\"][\"init_params\"]\n",
    "    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(**preprocessor_cfg)\n",
    "\n",
    "    # Text embedding: work on a copy so the shared config dict stays untouched.\n",
    "    text_embedding_params = copy.deepcopy(tacotron2_config[\"TextEmbedding\"][\"init_params\"])\n",
    "    text_embedding_params.update(n_symbols=len(labels) + 3)\n",
    "    text_embedding = nemo_tts.TextEmbedding(**text_embedding_params)\n",
    "    text_embedding.restore_from(TEXT_EMBEDDING)\n",
    "\n",
    "    # Encoder\n",
    "    encoder_cfg = tacotron2_config[\"Tacotron2Encoder\"][\"init_params\"]\n",
    "    t2_enc = nemo_tts.Tacotron2Encoder(**encoder_cfg)\n",
    "    t2_enc.restore_from(TACOTRON_ENCODER)\n",
    "\n",
    "    # Decoder (inference variant)\n",
    "    decoder_cfg = copy.deepcopy(tacotron2_config[\"Tacotron2Decoder\"][\"init_params\"])\n",
    "    t2_dec = nemo_tts.Tacotron2DecoderInfer(**decoder_cfg)\n",
    "    t2_dec.restore_from(TACOTRON_DECODER)\n",
    "\n",
    "    # Postnet\n",
    "    postnet_cfg = tacotron2_config[\"Tacotron2Postnet\"][\"init_params\"]\n",
    "    t2_postnet = nemo_tts.Tacotron2Postnet(**postnet_cfg)\n",
    "    t2_postnet.restore_from(TACOTRON_POSTNET)\n",
    "\n",
    "    # Loss and gate-target modules are created without restoring a checkpoint.\n",
    "    t2_loss = nemo_tts.Tacotron2Loss(**tacotron2_config[\"Tacotron2Loss\"][\"init_params\"])\n",
    "    makegatetarget = nemo_tts.MakeGate()\n",
    "\n",
    "    restored_modules = (text_embedding, t2_enc, t2_dec, t2_postnet)\n",
    "    total_weights = sum(module.num_weights for module in restored_modules)\n",
    "\n",
    "    logging.info(separator)\n",
    "    logging.info(f\"Total number of parameters (Tacotron 2): {total_weights}\")\n",
    "    logging.info(separator)\n",
    "\n",
    "    # WaveGlow: copy the config before overriding sigma.\n",
    "    waveglow_args = copy.deepcopy(waveglow_config[\"WaveGlowNM\"][\"init_params\"])\n",
    "    waveglow_args['sigma'] = waveglow_sigma\n",
    "    waveglow = nemo_tts.WaveGlowInferNM(**waveglow_args)\n",
    "    waveglow.restore_from(WAVEGLOW)\n",
    "\n",
    "    total_weights = waveglow.num_weights\n",
    "\n",
    "    logging.info(separator)\n",
    "    logging.info(f\"Total number of parameters (WaveGlow): {total_weights}\")\n",
    "    logging.info(separator)\n",
    "\n",
    "    tacotron2_modules = (\n",
    "        data_preprocessor,\n",
    "        text_embedding,\n",
    "        t2_enc,\n",
    "        t2_dec,\n",
    "        t2_postnet,\n",
    "        t2_loss,\n",
    "        makegatetarget,\n",
    "    )\n",
    "    return tacotron2_modules, waveglow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate every module and restore its checkpoint from the paths defined in the download cell.\n",
    "neural_modules, waveglow = create_NMs(tacotron2_config, waveglow_config, labels, decoder_infer=True, waveglow_sigma=0.6);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_text(text):\n",
    "    \"\"\"Write ``text`` to ``cache/input.txt`` and return the file path.\n",
    "\n",
    "    The file is overwritten on every call and contains ``text`` followed by\n",
    "    a single newline.\n",
    "\n",
    "    Args:\n",
    "        text: Value to store; it is rendered with ``str.format``.\n",
    "\n",
    "    Returns:\n",
    "        str: Path of the written file.\n",
    "    \"\"\"\n",
    "    # exist_ok removes the check-then-create race and keeps repeated calls safe.\n",
    "    os.makedirs('cache/', exist_ok=True)\n",
    "\n",
    "    fp = os.path.join('cache', 'input.txt')\n",
    "    # Leaving the with-block closes the file, which also flushes it.\n",
    "    with open(fp, 'w', encoding='utf8') as f:\n",
    "        f.write('{}\\n'.format(text))\n",
    "\n",
    "    logging.info(\"Updated input file with value : %s\", text)\n",
    "    return fp\n",
    "        \n",
    "def cleanup_cachedir():\n",
    "    \"\"\"Delete the ``cache/`` directory tree if it exists, then log the cleanup.\"\"\"\n",
    "    cache_present = os.path.exists('cache/')\n",
    "    if cache_present:\n",
    "        shutil.rmtree('cache/')\n",
    "\n",
    "    # The message is logged whether or not anything was removed.\n",
    "    logging.info(\"Cleaned up cache directory !\")\n",
    "    \n",
    "def plot_and_save_spec(spectrogram, i, save_dir=None):\n",
    "    \"\"\"Render ``spectrogram`` as an image and save it as ``spec_{i}.png``.\n",
    "\n",
    "    Args:\n",
    "        spectrogram: Array handed to ``Axes.imshow``.\n",
    "        i: Index used in the output file name.\n",
    "        save_dir: Optional output directory; when falsy the file is written\n",
    "            relative to the current working directory.\n",
    "    \"\"\"\n",
    "    file_name = f\"spec_{i}.png\"\n",
    "    target = os.path.join(save_dir, file_name) if save_dir else file_name\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 3))\n",
    "    image = ax.imshow(spectrogram, aspect=\"auto\", origin=\"lower\", interpolation='none')\n",
    "    fig.colorbar(image, ax=ax)\n",
    "    ax.set_xlabel(\"Frames\")\n",
    "    ax.set_ylabel(\"Channels\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(target)\n",
    "    # Close the figure so repeated calls do not accumulate open figures.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initializing the inference DAG\n",
    "\n",
    "To initialize the graph, we first build it with placeholder text. Later, we will accept the actual text that we want to convert to speech!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write placeholder text to cache/input.txt so the data layer has a file to read when the DAG is built.\n",
    "filepath = update_text(\"sample text\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create inference DAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tacotron 2 DAG\n",
    "# Only the embedding, encoder, decoder and postnet are needed for inference;\n",
    "# the preprocessor, loss and gate-target modules returned by create_NMs are discarded.\n",
    "(_, text_embedding, t2_enc, t2_dec, t2_postnet, _, _) = neural_modules\n",
    "\n",
    "# The three special ids follow the label vocabulary, matching the\n",
    "# len(labels) + 3 symbols configured for the text embedding in create_NMs.\n",
    "data_layer = SentenceDataLayer(\n",
    "    path=filepath,\n",
    "    labels=labels,\n",
    "    batch_size=1,\n",
    "    num_workers=0,\n",
    "    bos_id=len(labels),\n",
    "    eos_id=len(labels) + 1,\n",
    "    pad_id=len(labels) + 2,\n",
    "    shuffle=False,\n",
    ")\n",
    "# Outputs correspond to the 'texts' and 'texts_length' ports of the data layer.\n",
    "transcript, transcript_len = data_layer()\n",
    "\n",
    "transcript_embedded = text_embedding(char_phone=transcript)\n",
    "\n",
    "transcript_encoded = t2_enc(char_phone_embeddings=transcript_embedded, embedding_length=transcript_len,)\n",
    "\n",
    "mel_decoder, gate, alignments, mel_len = t2_dec(\n",
    "    char_phone_encoded=transcript_encoded, encoded_length=transcript_len,\n",
    ")\n",
    "\n",
    "mel_postnet = t2_postnet(mel_input=mel_decoder)\n",
    "\n",
    "# WaveGlow DAG\n",
    "# The vocoder consumes the postnet output, so audio_pred depends on the whole Tacotron 2 chain above.\n",
    "audio_pred = waveglow(mel_spectrogram=mel_postnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup inference tensors\n",
    "# Order matters: run_tacotron2 reads mel_len as the last element (evaluated_tensors[-1]).\n",
    "infer_tensors = [mel_postnet, gate, alignments, mel_len]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference DAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_tacotron2():\n",
    "    \"\"\"Evaluate the Tacotron 2 inference tensors and build a mel filterbank.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(evaluated_tensors, filterbank, mel_len_val)`` where\n",
    "        ``evaluated_tensors`` is the result of inferring ``infer_tensors``,\n",
    "        ``filterbank`` comes from ``librosa.filters.mel`` using the Tacotron 2\n",
    "        config values, and ``mel_len_val`` is the last evaluated tensor.\n",
    "    \"\"\"\n",
    "    logging.info(\"Running Tacotron 2\")\n",
    "    # Results are kept on the inference device (no CPU offload).\n",
    "    evaluated_tensors = neural_factory.infer(tensors=infer_tensors, offload_to_cpu=False)\n",
    "    logging.info(\"Done Running Tacotron 2\")\n",
    "\n",
    "    # infer_tensors lists mel_len last, so the final entry holds the mel lengths.\n",
    "    mel_len_val = evaluated_tensors[-1]\n",
    "\n",
    "    mel_kwargs = {\n",
    "        \"sr\": tacotron2_config[\"sample_rate\"],\n",
    "        \"n_fft\": tacotron2_config[\"n_fft\"],\n",
    "        \"n_mels\": tacotron2_config[\"n_mels\"],\n",
    "        \"fmax\": tacotron2_config[\"fmax\"],\n",
    "    }\n",
    "    filterbank = librosa.filters.mel(**mel_kwargs)\n",
    "\n",
    "    return evaluated_tensors, filterbank, mel_len_val\n",
    "\n",
    "def run_waveglow(save_dir, waveglow_denoiser_strength=0.0):\n",
    "    \"\"\"Synthesize audio and write one WAV file and one spectrogram image per sample.\n",
    "\n",
    "    Runs ``run_tacotron2`` for the mel filterbank and mel lengths, evaluates\n",
    "    ``audio_pred``, trims each waveform to ``mel length * n_stride`` samples and\n",
    "    saves ``sample_{k}.wav`` / ``spec_{k}.png`` with ``k = i * 32 + j`` (batch\n",
    "    index ``i``, position ``j`` within the batch).\n",
    "\n",
    "    Args:\n",
    "        save_dir: Output directory; when falsy, files are written relative to\n",
    "            the current working directory.\n",
    "        waveglow_denoiser_strength (float): When greater than 0, the WaveGlow\n",
    "            denoiser is set up and applied to every sample with this strength.\n",
    "    \"\"\"\n",
    "    # Only the filterbank and mel lengths of the Tacotron 2 pass are used here.\n",
    "    # NOTE(review): audio_pred is built on top of mel_postnet, so the infer call\n",
    "    # below presumably evaluates the Tacotron 2 graph a second time - confirm.\n",
    "    _tacotron2_outputs, filterbank, mel_len_val = run_tacotron2()\n",
    "\n",
    "    logging.info(\"Running Waveglow\")\n",
    "    evaluated_tensors = neural_factory.infer(tensors=[audio_pred])\n",
    "    logging.info(\"Done Running Waveglow\")\n",
    "\n",
    "    use_denoiser = waveglow_denoiser_strength > 0\n",
    "    if use_denoiser:\n",
    "        logging.info(\"Setup WaveGlow denoiser\")\n",
    "        waveglow.setup_denoiser()\n",
    "\n",
    "    logging.info(\"Saving results to disk\")\n",
    "    for i, batch in enumerate(evaluated_tensors[0]):\n",
    "        waveforms = batch.cpu().numpy()\n",
    "        for j, waveform in enumerate(waveforms):\n",
    "            sample_index = i * 32 + j\n",
    "\n",
    "            # Cut the waveform to the predicted mel length times n_stride.\n",
    "            sample_len = mel_len_val[i][j] * tacotron2_config[\"n_stride\"]\n",
    "            clip = waveform[:sample_len]\n",
    "\n",
    "            save_file = f\"sample_{sample_index}.wav\"\n",
    "            if save_dir:\n",
    "                save_file = os.path.join(save_dir, save_file)\n",
    "\n",
    "            if use_denoiser:\n",
    "                clip, spec = waveglow.denoise(clip, strength=waveglow_denoiser_strength)\n",
    "            else:\n",
    "                stft = librosa.core.stft(clip, n_fft=waveglow_config[\"n_fft\"])\n",
    "                spec, _phase = librosa.core.magphase(stft)\n",
    "            write(save_file, waveglow_config[\"sample_rate\"], clip)\n",
    "\n",
    "            # Project onto the mel filterbank and take a floored log for plotting.\n",
    "            mel_spec = np.dot(filterbank, spec)\n",
    "            log_mel_spec = np.log(np.clip(mel_spec, a_min=1e-5, a_max=None))\n",
    "            plot_and_save_spec(log_mel_spec, sample_index, save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run Tacotron 2 + WaveGlow on input text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt for the sentence to synthesize; the next cell writes it to the cached input file.\n",
    "text = input('Please enter some initial text here :')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the cached input file with the new text ...\n",
    "filepath = update_text(text)\n",
    "# ... then rebuild the dataset from the same path so the data layer picks up the new contents.\n",
    "data_layer.update_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare directories to save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = 'results/'\n",
    "# File names match what run_waveglow writes for the first sample (index 0).\n",
    "saved_audio = os.path.join(savedir, 'sample_0.wav')\n",
    "saved_spectrogram = os.path.join(savedir, 'spec_0.png')\n",
    "\n",
    "# exist_ok makes the cell safe to re-run without a separate existence check.\n",
    "os.makedirs(savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate the audio\n",
    "\n",
    "Let's run the Tacotron 2 model and send the results to WaveGlow to generate the audio!\n",
    "If there is subtle white noise, try increasing the denoiser strength a little at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A strength of 0 skips the denoiser; run_waveglow only applies it for values greater than 0.\n",
    "run_waveglow(savedir, waveglow_denoiser_strength=0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's hear the generated audio!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): saved_audio is a file path; `rate` is documented as needed only for raw array data,\n",
    "# so the 16000 here is presumably ignored and the WAV header's sample rate is used - confirm.\n",
    "ipd.Audio(saved_audio, rate=16000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the spectrogram image that plot_and_save_spec wrote for the first sample.\n",
    "ipd.Image(saved_spectrogram)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cleanup cachedir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the cache/ directory that update_text created for the input text file.\n",
    "cleanup_cachedir()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
